#####
# Script to generate figure 8. This script takes several hours to run on a machine with 60 cores and 256 GB RAM.
#####

# Function to generate trajectories to put into a stable basin.
#
# Simulates two satellite/debris trajectories over T periods, one starting from
# `init_lo` and one from `init_hi`, and joins them into a single data frame for
# plotting on top of a stable-basin figure.
#
# Arguments:
#   T              Number of periods passed to the launch-path solvers and gen_fts().
#   params         Model parameter data frame forwarded to the simulation helpers.
#   init_lo        Initial c(satellites, debris) in calibrated (unscaled) units;
#                  divided by the global s_scale_factor / d_scale_factor.
#   init_hi        As init_lo, for the second trajectory.
#   opt            0: launch paths from OA.ts(opt = 0).
#                  1: optimal launch paths -- from optim() on objective_fts_cpp
#                     (maximized via fnscale = -1, launches bounded to [0, 10])
#                     when optimal_policy is NULL, otherwise from OA.ts() driven
#                     by a multilinear interpolant of optimal_policy.
#                  Any other value is an error.
#   optimal_policy Data frame with columns S, D and X (used only when opt == 1).
#   ...            Ignored.
#
# Returns a data frame keyed by `time` holding the gen_fts() columns suffixed
# _lo / _hi, plus risk_lo / risk_hi from L(). launches_* and sats_* are
# multiplied by s_scale_factor and debs_* by d_scale_factor.
gen_trajs_stable_basin <- function(T=200, params, init_lo=c(110000,0), init_hi=c(0,1750000), opt=0, optimal_policy=NULL, ...) {
	# Fail fast: any other value would leave lp_lo_df / lp_hi_df undefined.
	if (!(length(opt) == 1L && opt %in% c(0, 1))) {
		stop("`opt` must be 0 or 1, got: ", paste(opt, collapse = ", "), call. = FALSE)
	}

	# Convert initial conditions from calibrated units to model units.
	init_lo[1] <- init_lo[1] / s_scale_factor
	init_lo[2] <- init_lo[2] / d_scale_factor
	init_hi[1] <- init_hi[1] / s_scale_factor
	init_hi[2] <- init_hi[2] / d_scale_factor

	# Launch series of an OA.ts() result with its final entry dropped. Each
	# trajectory is trimmed by its own length.
	# NOTE(review): presumably OA.ts() returns one more launch value than
	# gen_fts() consumes -- confirm.
	oa_launches_df <- function(ts_result) {
		launches <- as.data.frame(ts_result)$launches
		data.frame(launches = launches[-length(launches)])
	}

	if (opt == 1 && is.null(optimal_policy)) {
		# Optimize the launch path directly from each initial condition.
		init_launchpath <- rep(0, length.out = T)
		optimize_launches <- function(init) {
			optim(par = init_launchpath, fn = objective_fts_cpp, S = init[1], D = init[2], T = T, params = params,
				control = list(fnscale = -1), method = "L-BFGS-B", lower = 0, upper = 10)$par
		}
		lp_lo_df <- data.frame(launches = optimize_launches(init_lo))
		lp_hi_df <- data.frame(launches = optimize_launches(init_hi))
	} else if (opt == 1) {
		# Interpolate the supplied optimal policy and simulate with it.
		message("Generating interpolants for optimal policy...")
		grid_list_big <- list(S = unique(optimal_policy$S), D = unique(optimal_policy$D))
		opt_launch_rate_ipol <- ipol(optimal_policy$X, grid = grid_list_big, method = "multilinear")
		simulate_policy <- function(init) {
			OA.ts(init[1], init[2], T, 0, opt = 1, constraint = 0, params = params,
				opt_interpolant = opt_launch_rate_ipol, opt_policy_grid = grid_list_big, just_make_trajs = TRUE)
		}
		lp_lo_df <- oa_launches_df(simulate_policy(init_lo))
		lp_hi_df <- oa_launches_df(simulate_policy(init_hi))
	} else {
		# Open-access launch paths.
		lp_lo_df <- oa_launches_df(OA.ts(init_lo[1], init_lo[2], T, 0, opt = 0, constraint = 0, params = params))
		lp_hi_df <- oa_launches_df(OA.ts(init_hi[1], init_hi[2], T, 0, opt = 0, constraint = 0, params = params))
	}

	# One index per period 0..T, numbered from 1 (i.e. 1..T+1). Used as the
	# join key between the two trajectories.
	time_index <- seq_along(0:T)
	lo <- cbind(time = time_index, gen_fts(X = lp_lo_df$launches, S = init_lo[1], D = init_lo[2], T = T, params = params))
	hi <- cbind(time = time_index, gen_fts(X = lp_hi_df$launches, S = init_hi[1], D = init_hi[2], T = T, params = params))

	# mutate() evaluates left to right, so risk_lo / risk_hi are computed from
	# the model-unit sats / debs *before* those columns are rescaled below.
	# Keep this ordering.
	left_join(as.data.frame(lo), as.data.frame(hi), by = "time", suffix = c("_lo", "_hi")) %>%
		mutate(
			risk_lo = L(sats_lo, debs_lo, params),
			risk_hi = L(sats_hi, debs_hi, params),
			launches_lo = launches_lo * s_scale_factor,
			launches_hi = launches_hi * s_scale_factor,
			sats_lo = sats_lo * s_scale_factor,
			sats_hi = sats_hi * s_scale_factor,
			debs_lo = debs_lo * d_scale_factor,
			debs_hi = debs_hi * d_scale_factor
		)
}

# Body of script

library(ggplot2)
library(cowplot)
library(patchwork)
library(viridis)
library(tidyverse)
library(pracma)
library(nleqslv)
library(rootSolve)
library(fields)
library(doParallel)
library(chebpol)
library(scales)
library(data.table)
library(tictoc)

library(simFnsEqns)

## Uncomment the following lines to accelerate computation if necessary.
## enableJIT() lives in the base `compiler` package, which is not attached by default.
# compiler::enableJIT(3)
#system(sprintf("taskset -p 0xffffffffffffffffffff %d", Sys.getpid())) # Adjusts the R session's affinity mask from 1 to f, allowing the R process to use all cores.
source("equations.r")
source("simulation_functions.r")
source("planner_simulation.r")

# Baseline model parameters. These are kept as globals because later code
# reads them from the global environment.
p <- 1
discount <- 0.99
r <- (1 - discount) / discount
fe_eqm <- 0.06
F <- p / (r + fe_eqm)
d <- 0.6
m <- 2
aSS <- 0.05
aSD <- 0.005
aDD <- 0.09
bSS <- 10
bSD <- 5.375
bDD <- 1.3
# NOTE(review): params below stores Xbar = 1, not this Xbar -- confirm intended.
Xbar <- 100
T <- 150

params <- data.frame(
	p = p, F = F, discount = discount, r = r, fe_eqm = fe_eqm,
	d = d, m = m,
	aSS = aSS, aSD = aSD, aDD = aDD,
	bSS = bSS, bSD = bSD, bDD = bDD,
	T = T, Xbar = 1
)
params$avg_sat_decay <- 1 # 1: satellites are infinitely lived

# Return `prm` with a new discount factor, the implied rate
# r = (1 - discount) / discount, and F = p / (r + fe_eqm) recomputed from it.
set_discount_params <- function(prm, new_discount) {
	prm$discount <- new_discount
	prm$r <- (1 - new_discount) / new_discount
	prm$F <- prm$p / (prm$r + prm$fe_eqm)
	prm
}

# Variants copied while the horizon is still T = 150.
params2 <- params
params2$fe_eqm <- 0.05 * 1.36
params2 <- set_discount_params(params2, 0.75)

params3 <- params
params3$fe_eqm <- 0.05 * 2.25
params3$F <- params3$p / (params3$r + params3$fe_eqm)

# Variants copied after the baseline horizon is extended to T = 175.
# NOTE(review): params7 and params8 are not referenced elsewhere in this script.
params$T <- 175
params7 <- set_discount_params(params, 0.99)
params8 <- set_discount_params(params, 0.55)

params9 <- params
params9$fe_eqm <- params3$fe_eqm
params9 <- set_discount_params(params9, 0.15)

## Calibration scale factors.
s_scale_factor <- 116649.8 / 1.237508
d_scale_factor <- 367372.3 / 1.170672

tic()
# Worker count used by every solver / basin-finder call below.
n_cores <- 60

# Satellite-debris grids from make_sd_grid(): SD_grid_s is passed as the small
# initial-guess grid and SD_grid as the larger VFI grid.
# TODO: final runs were meant to use at least 500 points per axis; this uses 400.
SD_grid_s <- make_sd_grid(0, 2.5, 40, 0, 10, 40)
SD_grid <- make_sd_grid(0, 2.5, 400, 0, 10, 400)
# Unique S and D values of the large grid.
grid_list <- list(S = unique(SD_grid$S), D = unique(SD_grid$D))

# Optimal policies for params, params2 and params9.
opt_policy_1 <- solve_opt_policy(
	params = params, small_guess_grid = SD_grid_s, larger_vfi_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-1"
)
opt_policy_2 <- solve_opt_policy(
	params = params2, small_guess_grid = SD_grid_s, larger_vfi_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-2", steps_guarantee = 22
)
opt_policy_3 <- solve_opt_policy(
	params = params9, small_guess_grid = SD_grid_s, larger_vfi_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-3", steps_guarantee = 100
)

# Stable basins: simulate 100 periods from each initial condition to identify
# whether it is stable.
params$T <- 100
params2$T <- 100
params9$T <- 100

basins_results_1 <- find_stable_basins(
	params = params, optimal_policy = opt_policy_1, SD_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-1"
)
basins_results_2 <- find_stable_basins(
	params = params2, optimal_policy = opt_policy_2, SD_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-2"
)
basins_results_3 <- find_stable_basins(
	params = params9, optimal_policy = opt_policy_3, SD_grid = SD_grid,
	grid_list = grid_list, ncores = n_cores, label = "stable-basin-3-params9"
)

toc()


# Calculate trajectories in parallel. Entries 1-3 use the solved optimal
# policies (opt = 1) and entries 4-6 the open-access path (opt = 0), each for
# params, params2 and params9 in turn. Results come back in submission order.
tic()
params_list_traj <- list(params, params2, params9, params, params2, params9)
policy_list <- list(opt_policy_1, opt_policy_2, opt_policy_3, NULL, NULL, NULL)
opt_seq_traj <- rep(c(1, 0), each = 3)
T_seq <- rep(500, times = length(params_list_traj))

# One worker per trajectory.
cl <- makeCluster(length(params_list_traj))
registerDoParallel(cl)

opt_traj_list <- foreach(
	i = seq_along(params_list_traj),
	.inorder = TRUE,
	.export = ls(envir = globalenv()),
	.packages = c("chebpol", "simFnsEqns", "tidyverse")
) %dopar% {
	gen_trajs_stable_basin(
		T = T_seq[i],
		params = params_list_traj[[i]],
		opt = opt_seq_traj[i],
		optimal_policy = policy_list[[i]]
	)
}

stopCluster(cl)
toc()

# Optimal-policy trajectories.
opt_traj_1 <- opt_traj_list[[1]]
opt_traj_2 <- opt_traj_list[[2]]
opt_traj_3 <- opt_traj_list[[3]]

# Open-access trajectories.
oa_traj_1 <- opt_traj_list[[4]]
oa_traj_2 <- opt_traj_list[[5]]
oa_traj_3 <- opt_traj_list[[6]]

# Plot the results.
#
# 2x2 figure. The left column uses basins_results_1 (params) and the right
# column basins_results_3 (params9). The top row shows the $oa basins with the
# open-access trajectories; the bottom row shows the $opt basins with the
# optimal-policy trajectories.
# NOTE(review): basins_results_2, opt_traj_2 and oa_traj_2 (params2) are
# computed but not used in this figure.

# One basin panel: the figure_10_gen() base plot with the first n_lo rows of
# the "lo" trajectory (solid) and the first n_hi rows of the "hi" trajectory
# (dashed) overlaid.
# NOTE(review): every call below passes `params` to figure_10_gen(), including
# the panels for the params9 basins -- confirm this is intended for the
# nullclines.
stable_basin_panel <- function(model_params, basin_data, traj, n_lo, n_hi) {
	figure_10_gen(model_params, basin_data, labels = 0, FAR = 0, make_nullclines = 1,
		s_scale_factor = s_scale_factor, d_scale_factor = d_scale_factor) +
		geom_path(data = traj[1:n_lo, ], aes(y = sats_lo, x = debs_lo),
			color = "black", size = 1.5, arrow = arrow()) +
		geom_path(data = traj[1:n_hi, ], aes(y = sats_hi, x = debs_hi),
			linetype = "dashed", color = "black", size = 1.5, arrow = arrow())
}

# Comma-formatted axes, added last to every panel. ggplot2 keeps only the most
# recently added scale per aesthetic, so xlim()/ylim() added before these
# scales are silently discarded. To zoom, use coord_cartesian(), which does not
# drop the out-of-range trajectory points.
comma_axes <- list(
	scale_x_continuous(labels = comma),
	scale_y_continuous(labels = comma)
)

fig_1_oa <- stable_basin_panel(params, basins_results_1$oa, oa_traj_1, n_lo = 6, n_hi = 6) +
	# Fully transparent points at rows with path_convergence > 0; presumably
	# there to keep those locations inside the plotted range.
	geom_point(data = basins_results_1$oa[which(basins_results_1$oa$path_convergence > 0), ],
		aes(x = debris * d_scale_factor, y = satellites * s_scale_factor), alpha = 0) +
	comma_axes

fig_3_oa <- stable_basin_panel(params, basins_results_3$oa, oa_traj_3, n_lo = 5, n_hi = 2) +
	comma_axes

fig_1_opt <- stable_basin_panel(params, basins_results_1$opt, opt_traj_1, n_lo = 15, n_hi = 15) +
	comma_axes

fig_3_opt <- stable_basin_panel(params, basins_results_3$opt, opt_traj_3, n_lo = 15, n_hi = 4) +
	comma_axes

ggsave(
	filename = "../images/figure-8--kessler-syndrome-illustration.png",
	plot = ((fig_1_oa | fig_3_oa) / (fig_1_opt | fig_3_opt)) + plot_annotation(tag_levels = "A"),
	width = 12,
	height = 10,
	units = "in",
	dpi = 300
)
